# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Variations in 3D rotations."""


from dm_control.composer.variation import base
from dm_control.composer.variation import variation_values
from dm_control.utils import transformations
import numpy as np

# The identity (no-op) rotation as a unit quaternion; the scalar component
# comes first, followed by the three zero vector components.
IDENTITY_QUATERNION = np.array((1.0, 0.0, 0.0, 0.0))


class UniformQuaternion(base.Variation):
  """Variation that draws unit quaternions uniformly at random.

  Every call draws three uniform samples with a single `uniform` call on the
  random state. It then maps them onto a unit quaternion in closed form. The
  `initial_value` and `current_value` arguments are accepted for interface
  compatibility but are not used.
  """

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    rng = random_state if random_state else np.random
    lower_bounds = [0.] * 3
    upper_bounds = [1., 2. * np.pi, 2. * np.pi]
    mix, phase_a, phase_b = rng.uniform(lower_bounds, upper_bounds)
    # The first pair of components shares radius sqrt(1 - mix) and the second
    # pair shares radius sqrt(mix), so the squared norm always sums to 1.
    outer = np.sqrt(1. - mix)
    inner = np.sqrt(mix)
    components = [
        outer * np.sin(phase_a),
        outer * np.cos(phase_a),
        inner * np.sin(phase_b),
        inner * np.cos(phase_b),
    ]
    return np.array(components)

  def __eq__(self, other):
    # Instances hold no parameters, so any two of them are interchangeable.
    same_kind = isinstance(other, UniformQuaternion)
    return same_kind

  def __repr__(self):
    return "UniformQuaternion()"


class QuaternionFromAxisAngle(base.Variation):
  """Quaternion variation specified in terms of variations in axis and angle.

  Both `axis` and `angle` may be fixed values or variations. They are
  evaluated on every call and combined as `axis * angle` before being passed
  to `transformations.axisangle_to_quat`.
  """

  def __init__(self, axis, angle):
    """Initializes the variation.

    Args:
      axis: The rotation axis, or a variation producing it. NOTE(review): it is
        multiplied by `angle` to form the axis-angle vector, so it presumably
        needs to be a unit 3-vector; confirm against `axisangle_to_quat`.
      angle: The rotation angle, or a variation producing it. NOTE(review):
        units presumably radians; confirm against `axisangle_to_quat`.
    """
    self._axis = axis
    self._angle = angle

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Evaluates the axis and angle and returns the resulting quaternion.

    Args:
      initial_value: Forwarded to `variation_values.evaluate` for both the
        axis and the angle.
      current_value: Forwarded to `variation_values.evaluate` for both the
        axis and the angle.
      random_state: Random state used for evaluation. Falls back to the
        global `np.random` module when not provided.

    Returns:
      The value of `transformations.axisangle_to_quat(axis * angle)`.
    """
    random_state = random_state or np.random
    axis = variation_values.evaluate(
        self._axis, initial_value, current_value, random_state)
    angle = variation_values.evaluate(
        self._angle, initial_value, current_value, random_state)
    return transformations.axisangle_to_quat(np.asarray(axis) * angle)

  @staticmethod
  def _values_equal(a, b):
    """Returns whether two constructor arguments are equal, as a bool.

    Plain `==` on a numpy array yields an elementwise array whose truth value
    is ambiguous inside `and`, which would make `__eq__` raise. Arrays are
    therefore compared with `np.array_equal`. An array is never considered
    equal to a variation.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      if isinstance(a, base.Variation) or isinstance(b, base.Variation):
        return False
      return bool(np.array_equal(a, b))
    return a == b

  def __eq__(self, other):
    if not isinstance(other, QuaternionFromAxisAngle):
      return False
    return (
        self._values_equal(self._axis, other._axis)
        and self._values_equal(self._angle, other._angle)
    )

  def __repr__(self):
    return (
        f"QuaternionFromAxisAngle(axis={self._axis}, angle={self._angle})"
    )


class QuaternionPreMultiply(base.Variation):
  """A variation that pre-multiplies an existing quaternion value.

  This variation takes a quaternion value generated by another variation and
  pre-multiplies it to an existing value. In cumulative mode, the new quaternion
  is pre-multiplied to the current value being varied. In non-cumulative mode,
  the new quaternion is pre-multiplied to a fixed initial value.
  """

  def __init__(self, quat, cumulative=False):
    """Initializes the variation.

    Args:
      quat: The quaternion to pre-multiply, or a variation producing it. It is
        evaluated with `variation_values.evaluate` on every call.
      cumulative: If True, multiply onto `current_value`. Otherwise multiply
        onto `initial_value`.
    """
    self._quat = quat
    self._cumulative = cumulative

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Returns `quat_mul(quat, base_value)`.

    Args:
      initial_value: The base quaternion when not in cumulative mode.
      current_value: The base quaternion when in cumulative mode.
      random_state: Random state used to evaluate `quat`. Falls back to the
        global `np.random` module when not provided.

    Returns:
      The product of the evaluated quaternion (left operand) and the selected
      base value (right operand). NOTE(review): the selected base value is
      assumed to be a quaternion; `None` is not handled here.
    """
    random_state = random_state or np.random
    q1 = variation_values.evaluate(self._quat, initial_value, current_value,
                                   random_state)
    q2 = current_value if self._cumulative else initial_value
    return transformations.quat_mul(np.asarray(q1), np.asarray(q2))

  @staticmethod
  def _values_equal(a, b):
    """Returns whether two constructor arguments are equal, as a bool.

    Plain `==` on a numpy array yields an elementwise array whose truth value
    is ambiguous inside `and`, which would make `__eq__` raise. Arrays are
    therefore compared with `np.array_equal`. An array is never considered
    equal to a variation.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      if isinstance(a, base.Variation) or isinstance(b, base.Variation):
        return False
      return bool(np.array_equal(a, b))
    return a == b

  def __eq__(self, other):
    if not isinstance(other, QuaternionPreMultiply):
      return False
    return (
        self._values_equal(self._quat, other._quat)
        and self._cumulative == other._cumulative
    )

  def __repr__(self):
    return (
        f"QuaternionPreMultiply(quat={self._quat},"
        f" cumulative={self._cumulative})"
    )


class QuaternionRotate(base.Variation):
  """Variation that rotates a given vector by the given quaternion.

  The vector can either be an existing value passed at evaluation, or specified
  as a separate variation at construction. In the former case, cumulative mode
  determines whether to use the current or initial value of the vector. The
  quaternion is always specified by a variation at construction.
  """

  def __init__(self, quat, vec=None, cumulative=False):
    """Initializes the variation.

    Args:
      quat: The rotation quaternion, or a variation producing it.
      vec: The vector to rotate, or a variation producing it. If None, the
        vector is taken from the value passed at evaluation time.
      cumulative: Only used when `vec` is None. If True, rotate
        `current_value`; otherwise rotate `initial_value`.
    """
    self._quat = quat
    self._vec = vec
    self._cumulative = cumulative

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Returns `quat_rotate(quat, vec)` for the evaluated quaternion and vector.

    Args:
      initial_value: Forwarded to evaluation. It is also the vector to rotate
        when `vec` is None and not in cumulative mode.
      current_value: Forwarded to evaluation. It is also the vector to rotate
        when `vec` is None and in cumulative mode.
      random_state: Random state used for evaluation. Falls back to the
        global `np.random` module when not provided.

    Returns:
      The result of `transformations.quat_rotate(quat, vec)`.
    """
    random_state = random_state or np.random
    quat = variation_values.evaluate(
        self._quat, initial_value, current_value, random_state
    )
    if self._vec is None:
      vec = current_value if self._cumulative else initial_value
    else:
      vec = variation_values.evaluate(
          self._vec, initial_value, current_value, random_state
      )
    return transformations.quat_rotate(np.asarray(quat), np.asarray(vec))

  @staticmethod
  def _values_equal(a, b):
    """Returns whether two constructor arguments are equal, as a bool.

    Plain `==` on a numpy array yields an elementwise array whose truth value
    is ambiguous inside `and`, which would make `__eq__` raise. Arrays are
    therefore compared with `np.array_equal`. An array is never considered
    equal to a variation.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      if isinstance(a, base.Variation) or isinstance(b, base.Variation):
        return False
      return bool(np.array_equal(a, b))
    return a == b

  def __eq__(self, other):
    if not isinstance(other, QuaternionRotate):
      return False
    return (
        self._values_equal(self._quat, other._quat)
        and self._values_equal(self._vec, other._vec)
        and self._cumulative == other._cumulative
    )

  def __repr__(self):
    return (
        f"QuaternionRotate(quat={self._quat}, vec={self._vec},"
        f" cumulative={self._cumulative})"
    )
